package com.morty.openapi;

import com.alibaba.fastjson.JSON;
import com.morty.log.GatewayLog;
import com.morty.log.LogHelper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.reactivestreams.Publisher;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.NettyWriteResponseFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.cloud.gateway.filter.factory.rewrite.CachedBodyOutputMessage;
import org.springframework.cloud.gateway.route.Route;
import org.springframework.cloud.gateway.support.BodyInserterContext;
import org.springframework.cloud.gateway.support.ServerWebExchangeUtils;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ReactiveHttpOutputMessage;
import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.stereotype.Component;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.BodyInserter;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.server.HandlerStrategies;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.nio.charset.StandardCharsets;
import java.util.Date;
import java.util.List;
import java.util.Map;

@Component
@Slf4j
public class OpenApiLogGatewayFilterFactory extends AbstractGatewayFilterFactory<OpenApiLogGatewayFilterFactory.Config> {

    /**
     * Message readers used to build a {@link ServerRequest} so the request body can be decoded as a String.
     */
    private final List<HttpMessageReader<?>> messageReaders = HandlerStrategies.withDefaults().messageReaders();

    public OpenApiLogGatewayFilterFactory() {
        super(Config.class);
    }

    /**
     * Creates the logging filter for a route. {@link Config} currently carries no options.
     */
    @Override
    public GatewayFilter apply(Config config) {
        return new OpenApiResponseLogGatewayFilter();
    }


    /**
     * Configuration holder for this filter factory; currently empty.
     */
    public static class Config {

    }

    /**
     * Gateway filter that logs request metadata, the request payload (query string or body) and,
     * for responses whose original content type contains {@code application/json}, the response
     * body together with the execution time.
     */
    public class OpenApiResponseLogGatewayFilter implements GatewayFilter, Ordered {

        @Override
        public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {

            ServerHttpRequest request = exchange.getRequest();

            GatewayLog gatewayLog = assembleLog(exchange, request);

            MediaType mediaType = request.getHeaders().getContentType();

            // Only form-urlencoded and JSON request bodies are read and logged;
            // any other request logs its query string instead.
            if (MediaType.APPLICATION_FORM_URLENCODED.isCompatibleWith(mediaType)
                    || MediaType.APPLICATION_JSON.isCompatibleWith(mediaType)) {
                return writeBodyLog(exchange, chain, gatewayLog);
            }
            return writeBasicLog(exchange, chain, gatewayLog);
        }

        /**
         * Builds the log record from request metadata: scheme, method, path within the application,
         * target route id, request time and client IP.
         */
        private GatewayLog assembleLog(ServerWebExchange exchange, ServerHttpRequest request) {
            // Request path
            String requestPath = request.getPath().pathWithinApplication().value();
            Route route = getGatewayRoute(exchange);
            String ipAddress = LogHelper.getRealIpAddress(request);

            GatewayLog gatewayLog = new GatewayLog();
            gatewayLog.setSchema(request.getURI().getScheme());
            gatewayLog.setRequestMethod(request.getMethodValue());
            gatewayLog.setRequestPath(requestPath);
            // Do not fail the request with an NPE if the route attribute is missing.
            gatewayLog.setTargetServer(route != null ? route.getId() : null);
            gatewayLog.setRequestTime(new Date());
            gatewayLog.setIp(ipAddress);
            return gatewayLog;
        }

        /**
         * Logs requests whose body is not read: the query string is stored as the request payload,
         * the response is decorated for logging, and the access log is written once the chain completes.
         */
        private Mono<Void> writeBasicLog(ServerWebExchange exchange, GatewayFilterChain chain, GatewayLog accessLog) {
            accessLog.setRequestBody(formatQueryParams(exchange.getRequest().getQueryParams()));

            // Capture the response body
            ServerHttpResponseDecorator decoratedResponse = recordResponseLog(exchange, accessLog);

            return chain.filter(exchange.mutate().response(decoratedResponse).build())
                    .then(Mono.fromRunnable(() -> {
                        // Write the access log
                        writeAccessLog(accessLog);
                    }));
        }

        /**
         * Renders query parameters as {@code k1=v1,v2&k2=v3}: entries are separated by {@code &},
         * and multiple values of the same key are joined with a comma.
         */
        private String formatQueryParams(MultiValueMap<String, String> queryParams) {
            StringBuilder builder = new StringBuilder();
            for (Map.Entry<String, List<String>> entry : queryParams.entrySet()) {
                if (builder.length() > 0) {
                    builder.append('&');
                }
                builder.append(entry.getKey()).append('=').append(StringUtils.join(entry.getValue(), ","));
            }
            return builder.toString();
        }


        /**
         * Logs the request body while still forwarding it downstream. This works around the fact
         * that a request body can only be read once.
         * Modeled on org.springframework.cloud.gateway.filter.factory.rewrite.ModifyRequestBodyGatewayFilterFactory.
         */
        private Mono<Void> writeBodyLog(ServerWebExchange exchange, GatewayFilterChain chain, GatewayLog gatewayLog) {
            ServerRequest serverRequest = ServerRequest.create(exchange, messageReaders);

            // Record the body as it is decoded; the value itself is passed through unchanged.
            Mono<String> modifiedBody = serverRequest.bodyToMono(String.class)
                    .doOnNext(body -> {
                        gatewayLog.setRequestBody(body);
                        log.info("open-api-request:{}", JSON.toJSONString(gatewayLog));
                    });

            // Re-insert the body through a BodyInserter (which would also allow modifying it),
            // so it remains available to the downstream request.
            BodyInserter<Mono<String>, ReactiveHttpOutputMessage> bodyInserter = BodyInserters.fromPublisher(modifiedBody, String.class);
            HttpHeaders headers = new HttpHeaders();
            headers.putAll(exchange.getRequest().getHeaders());
            // The new content length will be computed by the bodyInserter
            // and then applied in the request decorator.
            headers.remove(HttpHeaders.CONTENT_LENGTH);

            CachedBodyOutputMessage outputMessage = new CachedBodyOutputMessage(exchange, headers);

            return bodyInserter.insert(outputMessage, new BodyInserterContext())
                    .then(Mono.defer(() -> {
                        // Rebuild the request around the cached body
                        ServerHttpRequest decoratedRequest = requestDecorate(exchange, headers, outputMessage);

                        // Capture the response body
                        ServerHttpResponseDecorator decoratedResponse = recordResponseLog(exchange, gatewayLog);

                        return chain.filter(exchange.mutate().request(decoratedRequest).response(decoratedResponse).build())
                                .then(Mono.fromRunnable(() -> {
                                    // Write the access log
                                    writeAccessLog(gatewayLog);
                                }));
                    }));
        }

        /**
         * Writes the final access log entry.
         */
        private void writeAccessLog(GatewayLog gatewayLog) {
            log.info("open-api-end:{}", JSON.toJSONString(gatewayLog));
        }


        /**
         * Returns the route stored in the exchange under {@code GATEWAY_ROUTE_ATTR}, or {@code null} if absent.
         */
        private Route getGatewayRoute(ServerWebExchange exchange) {
            return exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR);
        }


        /**
         * Request decorator that serves the cached body and recomputes the length-related headers.
         *
         * @param headers       rebuilt headers (original request headers minus Content-Length, plus
         *                      whatever the body inserter wrote)
         * @param outputMessage message holding the re-inserted body
         */
        private ServerHttpRequestDecorator requestDecorate(ServerWebExchange exchange, HttpHeaders headers,
                                                           CachedBodyOutputMessage outputMessage) {
            return new ServerHttpRequestDecorator(exchange.getRequest()) {
                @Override
                public HttpHeaders getHeaders() {
                    long contentLength = headers.getContentLength();
                    HttpHeaders httpHeaders = new HttpHeaders();
                    // Start from the rebuilt headers, not the original request headers, so a stale
                    // Content-Length from the original request cannot leak through.
                    httpHeaders.putAll(headers);
                    if (contentLength > 0) {
                        httpHeaders.setContentLength(contentLength);
                    } else {
                        // Content-Length and chunked Transfer-Encoding must not be sent together.
                        // TODO: this causes a 'HTTP/1.1 411 Length Required' on httpbin.org
                        httpHeaders.remove(HttpHeaders.CONTENT_LENGTH);
                        httpHeaders.set(HttpHeaders.TRANSFER_ENCODING, "chunked");
                    }
                    return httpHeaders;
                }

                @Override
                public Flux<DataBuffer> getBody() {
                    return outputMessage.getBody();
                }
            };
        }


        /**
         * Decorates the response so its body can be logged.
         * Response time and execution time are recorded when a Flux body is written. If the original
         * response content type contains {@code application/json}, all body chunks are merged so a
         * body delivered in several segments is logged as one string.
         */
        private ServerHttpResponseDecorator recordResponseLog(ServerWebExchange exchange, GatewayLog gatewayLog) {
            ServerHttpResponse response = exchange.getResponse();
            DataBufferFactory bufferFactory = response.bufferFactory();
            // Used only to merge the response chunks for logging; created once per decorator
            // instead of once per emitted list.
            DataBufferFactory joinFactory = new DefaultDataBufferFactory();

            return new ServerHttpResponseDecorator(response) {
                @Override
                public Mono<Void> writeWith(Publisher<? extends DataBuffer> body) {
                    if (!(body instanceof Flux)) {
                        // Not expected in practice; pass through untouched.
                        return super.writeWith(body);
                    }

                    Date responseTime = new Date();
                    gatewayLog.setResponseTime(responseTime);
                    // Execution time in milliseconds between request start and first response write
                    long executeTime = (responseTime.getTime() - gatewayLog.getRequestTime().getTime());
                    gatewayLog.setExecuteTime(executeTime);

                    // Only JSON responses get their body logged
                    String originalResponseContentType = exchange.getAttribute(ServerWebExchangeUtils.ORIGINAL_RESPONSE_CONTENT_TYPE_ATTR);
                    if (StringUtils.isBlank(originalResponseContentType)
                            || !originalResponseContentType.contains("application/json")) {
                        return super.writeWith(body);
                    }

                    Flux<? extends DataBuffer> fluxBody = Flux.from(body);
                    return super.writeWith(fluxBody.buffer().map(dataBuffers -> {
                        // Merge all chunks so a segmented response body is logged as a whole
                        DataBuffer joined = joinFactory.join(dataBuffers);
                        byte[] content = new byte[joined.readableByteCount()];
                        joined.read(content);

                        // Release the merged buffer; the bytes are re-wrapped below
                        DataBufferUtils.release(joined);
                        String responseResult = new String(content, StandardCharsets.UTF_8);

                        gatewayLog.setResponseData(responseResult);
                        log.info("open-api-response:{}", JSON.toJSONString(gatewayLog));
                        return bufferFactory.wrap(content);
                    }));
                }
            };
        }

        /**
         * Orders this filter one step before {@link NettyWriteResponseFilter}. The response decorator
         * installed here is therefore present on the exchange that filter sees.
         */
        @Override
        public int getOrder() {
            return NettyWriteResponseFilter.WRITE_RESPONSE_FILTER_ORDER - 1;
        }
    }
}
